#
# Copyright © 2024 Agora
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0, with certain conditions.
# Refer to the "LICENSE" file in the root directory for more information.
#
from unittest.mock import patch, AsyncMock
import asyncio
import filecmp
import json
import os
import shutil
import threading

from ten_runtime import (
    ExtensionTester,
    TenEnvTester,
    Data,
)
from ten_ai_base.struct import TTSTextInput, TTSFlush
from ..tencent_tts import (
    MESSAGE_TYPE_PCM,
    MESSAGE_TYPE_CMD_COMPLETE,
    MESSAGE_TYPE_CMD_METRIC,
)


# ================ test dump file functionality ================
class ExtensionTesterDump(ExtensionTester):
    """Tester that records every received audio frame so the recording can be
    compared byte-for-byte with the extension's own PCM dump file."""

    def __init__(self):
        super().__init__()
        # Use a fixed path as requested by the user.
        self.dump_dir = "./dump/"
        # Use a unique name for the file generated by the test to avoid collision
        # with the file generated by the extension.
        self.test_dump_file_path = os.path.join(
            self.dump_dir, "test_manual_dump.pcm"
        )
        self.audio_end_received = False
        # Raw bytes copied out of each audio frame, in arrival order.
        self.received_audio_chunks = []

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Called when test starts, sends a TTS request."""
        ten_env_tester.log_info("Dump test started, sending TTS request.")

        tts_input = TTSTextInput(
            request_id="tts_request_1",
            text="hello word, hello agora",
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Stop the test as soon as ``tts_audio_end`` arrives."""
        name = data.get_name()
        if name == "tts_audio_end":
            ten_env.log_info("Received tts_audio_end, stopping test.")
            self.audio_end_received = True
            ten_env.stop_test()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Receives audio frames and collects their data using the lock/unlock pattern."""
        # The 'audio_frame' object is a wrapper around a memory buffer.
        # We must lock the buffer to safely access the data, copy it,
        # and finally unlock the buffer so the runtime can reuse it.
        buf = audio_frame.lock_buf()
        try:
            # We must copy the data from the buffer, as the underlying memory
            # may be freed or reused after we unlock it.
            copied_data = bytes(buf)
            self.received_audio_chunks.append(copied_data)
        finally:
            # Always ensure the buffer is unlocked, even if an error occurs.
            audio_frame.unlock_buf(buf)

    def write_test_dump_file(self):
        """Write the collected audio chunks, in arrival order, to
        ``self.test_dump_file_path``.

        Creates the dump directory first if it does not exist yet.
        """
        os.makedirs(self.dump_dir, exist_ok=True)
        with open(self.test_dump_file_path, "wb") as f:
            f.writelines(self.received_audio_chunks)

    def find_tts_dump_file(self) -> str | None:
        """Find the dump file created by the TTS extension in the fixed dump directory.

        Entries are scanned in sorted order so the choice is deterministic when
        more than one ``.pcm`` file is present (``os.listdir`` order is
        arbitrary). The tester's own dump file and non-regular entries are
        skipped.

        Returns:
            Path of the first matching ``.pcm`` file, or ``None`` when the
            dump directory is missing or has no candidate.
        """
        if not os.path.isdir(self.dump_dir):
            return None
        own_file_name = os.path.basename(self.test_dump_file_path)
        for filename in sorted(os.listdir(self.dump_dir)):
            if not filename.endswith(".pcm") or filename == own_file_name:
                continue
            candidate = os.path.join(self.dump_dir, filename)
            if os.path.isfile(candidate):
                return candidate
        return None


@patch("tencent_tts_python.extension.TencentTTSClient")
def test_dump_functionality(MockTencentTTSClient):
    """Tests that the dump file from the TTS extension matches the audio received by the test extension.

    TencentTTSClient is replaced by a mock that streams two fixed PCM chunks.
    The extension is configured to dump its output into ``./dump/``. The tester
    writes every audio frame it receives into its own file in the same
    directory, and the two files must be byte-identical.
    """

    print("Starting test_dump_functionality with mock...")

    # --- Directory Setup ---
    # As requested, use a fixed './dump/' directory (same value as
    # ExtensionTesterDump.dump_dir).
    DUMP_PATH = "./dump/"

    # Clean up directory before the test, in case of previous failed runs.
    if os.path.exists(DUMP_PATH):
        shutil.rmtree(DUMP_PATH)
    os.makedirs(DUMP_PATH)

    # --- Mock Configuration ---
    mock_instance = MockTencentTTSClient.return_value
    mock_instance.start = AsyncMock()
    mock_instance.stop = AsyncMock()

    # Create some fake audio data to be streamed
    fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20
    fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20

    # Mock synthesize_audio and get_audio_data with proper timing using asyncio.Queue.
    # The queue is created outside a running loop; on Python 3.10+ asyncio.Queue
    # binds to the event loop lazily on first use.
    audio_queue = asyncio.Queue()

    async def mock_synthesize_audio(text: str, text_input_end: bool):
        # Add audio data to queue when synthesis starts
        audio_queue.put_nowait((False, MESSAGE_TYPE_CMD_METRIC, 200))
        audio_queue.put_nowait((False, MESSAGE_TYPE_PCM, fake_audio_chunk_1))
        audio_queue.put_nowait((False, MESSAGE_TYPE_PCM, fake_audio_chunk_2))
        audio_queue.put_nowait((True, MESSAGE_TYPE_CMD_COMPLETE, b""))
        print(
            f"Mock synthesize_audio called with text: {text}, text_input_end: {text_input_end}"
        )

    async def mock_get_audio_data():
        return await audio_queue.get()

    mock_instance.synthesize_audio.side_effect = mock_synthesize_audio
    mock_instance.get_audio_data.side_effect = mock_get_audio_data

    # --- Test Setup ---
    tester = ExtensionTesterDump()

    dump_config = {
        "dump": True,
        "dump_path": DUMP_PATH,
        "params": {
            "app_id": "1234567890",
            "secret_id": "test_secret_id",
            "secret_key": "test_secret_key",
            "sample_rate": 24000,
            "voice_type": 0,
        },
    }

    tester.set_test_mode_single("tencent_tts_python", json.dumps(dump_config))

    try:
        print("Running dump test...")
        tester.run()
        print("Dump test completed.")

        # --- Assertions ---
        assert tester.audio_end_received, "tts_audio_end was not received"
        # Without any received audio, two empty files would compare equal and
        # the binary comparison below would pass vacuously.
        assert (
            tester.received_audio_chunks
        ), "No audio frames were received by the tester"

        # Write the audio chunks collected by the test extension to its own dump file
        tester.write_test_dump_file()
        assert os.path.exists(
            tester.test_dump_file_path
        ), "Test dump file was not created"

        # Find the dump file automatically created by the TTS extension
        tts_dump_file = tester.find_tts_dump_file()
        assert (
            tts_dump_file is not None
        ), f"Could not find TTS-generated dump file in {DUMP_PATH}"

        print(
            f"Comparing TTS dump file: {tts_dump_file}, file size: {os.path.getsize(tts_dump_file)}"
        )
        print(
            f"With test dump file:    {tester.test_dump_file_path}, file size: {os.path.getsize(tester.test_dump_file_path)}"
        )

        # Binary comparison of the two files
        assert filecmp.cmp(
            tts_dump_file, tester.test_dump_file_path, shallow=False
        ), "The TTS dump file and the test-generated dump file do not match."

        print("✅ Dump file binary comparison passed.")

    finally:
        # Cleanup the dump directory after the test.
        if os.path.exists(DUMP_PATH):
            shutil.rmtree(DUMP_PATH)


# ================ test text_input_end logic ================
class ExtensionTesterTextInputEnd(ExtensionTester):
    """Tester checking that text sent for a request after it was closed with
    ``text_input_end=True`` is answered with an error."""

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        # Progress flags inspected by test_text_input_end_logic.
        self.first_request_audio_end_received = False
        self.second_request_error_received = False
        # Fields copied from the error payload once it arrives.
        self.error_code = None
        self.error_message = None
        self.error_module = None

    def _send_text_input(
        self, ten_env: TenEnvTester, tts_input: TTSTextInput
    ) -> None:
        """Wrap ``tts_input`` in a ``tts_text_input`` data message and send it."""
        message = Data.create("tts_text_input")
        message.set_property_from_json(None, tts_input.model_dump_json())
        ten_env.send_data(message)

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Remember the tester env and send the first (final) text input."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info(
            "TextInputEnd test started, sending first TTS request."
        )

        # Step 1: a complete request (text_input_end=True).
        closing_input = TTSTextInput(
            request_id="tts_request_1",
            text="hello word, hello agora",
            text_input_end=True,
        )
        self._send_text_input(ten_env_tester, closing_input)
        ten_env_tester.on_start_done()

    def send_second_request(self):
        """Sends the second TTS request that should be ignored."""
        env = self.ten_env
        if env is not None:
            env.log_info("Sending second TTS request, expecting an error.")
            # Step 2: same request_id again, this time text_input_end=False.
            late_input = TTSTextInput(
                request_id="tts_request_1",
                text="this should be ignored",
                text_input_end=False,
            )
            self._send_text_input(env, late_input)

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Send the follow-up after the first audio end, then record the
        error that the follow-up is expected to provoke and stop the test."""
        data_name = data.get_name()
        ten_env.log_info(f"Received data: {data_name}")

        if data_name == "tts_audio_end":
            # Only the first audio_end triggers the follow-up request.
            if self.first_request_audio_end_received:
                return
            ten_env.log_info("Received tts_audio_end for the first request.")
            self.first_request_audio_end_received = True
            self.send_second_request()
            return

        property_json, _ = data.get_property_to_json(None)
        ten_env.log_info(f"Received data: {property_json}")
        if not property_json:
            return

        body = json.loads(property_json)
        request_id = body.get("id")
        if data_name != "error" or request_id != "tts_request_1":
            return

        ten_env.log_info(
            f"Received expected error for the second request: {body}"
        )
        self.second_request_error_received = True
        self.error_code, self.error_message, self.error_module = (
            body.get("code"),
            body.get("message"),
            body.get("module"),
        )
        ten_env.stop_test()


@patch("tencent_tts_python.extension.TencentTTSClient")
def test_text_input_end_logic(MockTencentTTSClient):
    """
    Verify that once a request has been completed with text_input_end=True,
    another text input carrying the same request_id (text_input_end=False)
    is rejected with an error (code 1000) instead of being synthesized.
    """
    print("Starting test_text_input_end_logic with mock...")

    # --- Mock configuration: a client that replays a fixed script ---
    client = MockTencentTTSClient.return_value
    client.start = AsyncMock()
    client.stop = AsyncMock()

    pcm_part_a = b"\x11\x22\x33\x44" * 20
    pcm_part_b = b"\xaa\xbb\xcc\xdd" * 20
    # Tuples handed out by get_audio_data; the first element is True only
    # on the final CMD_COMPLETE message.
    scripted_events = (
        (False, MESSAGE_TYPE_CMD_METRIC, 200),
        (False, MESSAGE_TYPE_PCM, pcm_part_a),
        (False, MESSAGE_TYPE_PCM, pcm_part_b),
        (True, MESSAGE_TYPE_CMD_COMPLETE, b""),
    )
    pending_events = asyncio.Queue()

    async def fake_synthesize_audio(text: str, text_input_end: bool):
        # Each synthesis call enqueues the whole script.
        for event in scripted_events:
            pending_events.put_nowait(event)
        print(
            f"Mock synthesize_audio called with text: {text}, text_input_end: {text_input_end}"
        )

    async def fake_get_audio_data():
        return await pending_events.get()

    client.synthesize_audio.side_effect = fake_synthesize_audio
    client.get_audio_data.side_effect = fake_get_audio_data

    # --- Tester setup ---
    credentials = {
        "app_id": "1234567890",
        "secret_id": "test_secret_id",
        "secret_key": "test_secret_key",
        "sample_rate": 24000,
        "voice_type": 0,
    }
    tester = ExtensionTesterTextInputEnd()
    tester.set_test_mode_single(
        "tencent_tts_python", json.dumps({"params": credentials})
    )

    print("Running text_input_end logic test...")
    tester.run()
    print("text_input_end logic test completed.")

    # --- Assertions ---
    assert tester.first_request_audio_end_received, (
        "Did not receive tts_audio_end for the first request."
    )
    assert tester.second_request_error_received, (
        "Did not receive the expected error for the second request."
    )
    assert tester.error_code == 1000, (
        f"Expected error code 1000, but got {tester.error_code}"
    )

    print("✅ Text input end logic test passed successfully.")


# ================ test flush logic ================
class ExtensionTesterFlush(ExtensionTester):
    """Tester that sends ``tts_flush`` on the first audio frame, counts audio
    bytes, and flags any audio frame arriving after ``tts_flush_end``."""

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        # Event flags observed during the run.
        self.audio_start_received = False
        self.first_audio_frame_received = False
        self.flush_start_received = False
        self.audio_end_received = False
        self.flush_end_received = False
        self.audio_received_after_flush_end = False
        # Values copied from the tts_audio_end payload.
        self.audio_end_reason = ""
        self.total_audio_duration_from_event = 0
        # Byte accounting and assumed PCM format: 24 kHz (same as the test
        # config's sample_rate), 16-bit, mono.
        self.received_audio_bytes = 0
        self.sample_rate = 24000
        self.bytes_per_sample = 2  # 16-bit
        self.channels = 1

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Send one long text input so the flush can land mid-stream."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info("Flush test started, sending long TTS request.")
        long_input = TTSTextInput(
            request_id="tts_request_for_flush",
            text="This is a very long text designed to generate a continuous stream of audio, providing enough time to send a flush command.",
        )
        request_msg = Data.create("tts_text_input")
        request_msg.set_property_from_json(None, long_input.model_dump_json())
        ten_env_tester.send_data(request_msg)
        ten_env_tester.on_start_done()

    @staticmethod
    def _send_flush(ten_env: TenEnvTester) -> None:
        """Send a ``tts_flush`` for the request issued in on_start."""
        flush_msg = Data.create("tts_flush")
        flush_msg.set_property_from_json(
            None,
            TTSFlush(flush_id="tts_request_for_flush").model_dump_json(),
        )
        ten_env.send_data(flush_msg)

    @staticmethod
    def _schedule_stop(ten_env: TenEnvTester, delay_s: float) -> None:
        """Call ``ten_env.stop_test()`` after ``delay_s`` seconds."""

        def stop_when_due():
            ten_env.log_info("Waited after flush_end, stopping test now.")
            ten_env.stop_test()

        # threading.Timer instead of asyncio: on_data runs in a non-async
        # context, so there is no running event loop to schedule on.
        threading.Timer(delay_s, stop_when_due).start()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Track late audio, trigger the flush once, and count frame bytes."""
        # Any frame after tts_flush_end means the flush did not stop audio.
        if self.flush_end_received:
            ten_env.log_error("Received audio frame after tts_flush_end!")
            self.audio_received_after_flush_end = True

        # The very first frame triggers the flush.
        if not self.first_audio_frame_received:
            self.first_audio_frame_received = True
            ten_env.log_info("First audio frame received, sending flush data.")
            self._send_flush(ten_env)

        locked = audio_frame.lock_buf()
        try:
            frame_size = len(locked)
        finally:
            audio_frame.unlock_buf(locked)
        self.received_audio_bytes += frame_size

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Record start/end notifications; schedule the stop after flush_end."""
        event_name = data.get_name()
        ten_env.log_info(f"on_data name: {event_name}")

        # Notifications whose arrival is all we record (payload not read).
        flag_for_event = {
            "tts_audio_start": "audio_start_received",
            "tts_flush_start": "flush_start_received",
        }
        if event_name in flag_for_event:
            setattr(self, flag_for_event[event_name], True)
            return

        raw_json, _ = data.get_property_to_json(None)
        if not raw_json:
            return
        event_payload = json.loads(raw_json)
        ten_env.log_info(f"on_data payload: {event_payload}")

        if event_name == "tts_audio_end":
            self.audio_end_received = True
            self.audio_end_reason = event_payload.get("reason")
            self.total_audio_duration_from_event = event_payload.get(
                "request_total_audio_duration_ms"
            )
        elif event_name == "tts_flush_end":
            self.flush_end_received = True
            # Give late audio frames a chance to show up before stopping.
            self._schedule_stop(ten_env, 0.5)

    def get_calculated_audio_duration_ms(self) -> int:
        """Return the received audio length in whole milliseconds (truncated)."""
        bytes_per_second = (
            self.sample_rate * self.bytes_per_sample * self.channels
        )
        return int(self.received_audio_bytes / bytes_per_second * 1000)


@patch("tencent_tts_python.extension.TencentTTSClient")
def test_flush_logic(MockTencentTTSClient):
    """
    Tests that sending a flush command during TTS streaming correctly stops
    the audio and sends the appropriate events.

    The mocked client streams two fixed PCM chunks. The tester flushes on
    the first audio frame. The test then checks the emitted events, checks
    that no audio arrives after tts_flush_end, and checks that the reported
    duration matches the bytes received (within 10 ms).
    """
    print("Starting test_flush_logic with mock...")

    mock_instance = MockTencentTTSClient.return_value
    mock_instance.start = AsyncMock()
    mock_instance.stop = AsyncMock()

    # Create some fake audio data to be streamed
    fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20
    fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20

    # Mock synthesize_audio and get_audio_data with proper timing using asyncio.Queue
    audio_queue = asyncio.Queue()

    async def mock_synthesize_audio(text: str, text_input_end: bool):
        # Add audio data to queue when synthesis starts
        audio_queue.put_nowait((False, MESSAGE_TYPE_CMD_METRIC, 200))
        audio_queue.put_nowait((False, MESSAGE_TYPE_PCM, fake_audio_chunk_1))
        audio_queue.put_nowait((False, MESSAGE_TYPE_PCM, fake_audio_chunk_2))
        audio_queue.put_nowait((True, MESSAGE_TYPE_CMD_COMPLETE, b""))
        print(
            f"Mock synthesize_audio called with text: {text}, text_input_end: {text_input_end}"
        )

    async def mock_get_audio_data():
        return await audio_queue.get()

    mock_instance.synthesize_audio.side_effect = mock_synthesize_audio
    mock_instance.get_audio_data.side_effect = mock_get_audio_data

    config = {
        "params": {
            "app_id": "1234567890",
            "secret_id": "test_secret_id",
            "secret_key": "test_secret_key",
            "sample_rate": 24000,
            "voice_type": 0,
        }
    }
    tester = ExtensionTesterFlush()
    tester.set_test_mode_single("tencent_tts_python", json.dumps(config))

    print("Running flush logic test...")
    tester.run()
    print("Flush logic test completed.")

    assert tester.audio_start_received, "Did not receive tts_audio_start."
    assert tester.first_audio_frame_received, "Did not receive any audio frame."
    assert tester.audio_end_received, "Did not receive tts_audio_end."
    assert tester.flush_end_received, "Did not receive tts_flush_end."
    assert (
        not tester.audio_received_after_flush_end
    ), "Received audio after tts_flush_end."

    calculated_duration = tester.get_calculated_audio_duration_ms()
    event_duration = tester.total_audio_duration_from_event
    print(
        f"calculated_duration: {calculated_duration}, event_duration: {event_duration}"
    )
    # The tester stores payload.get(...), so a missing field yields None.
    # Fail with a clear message instead of a TypeError from abs() below.
    assert (
        event_duration is not None
    ), "tts_audio_end payload did not include request_total_audio_duration_ms."
    assert (
        abs(calculated_duration - event_duration) < 10
    ), f"Mismatch in audio duration. Calculated: {calculated_duration}ms, From event: {event_duration}ms"

    print("✅ Flush logic test passed successfully.")
